--- external/eg3d/training/training_loop.py	2023-04-06 03:56:36.237864870 +0000
+++ external_reference/eg3d/training/training_loop.py	2023-04-06 03:41:04.630726517 +0000
@@ -26,9 +26,9 @@
 from torch_utils.ops import grid_sample_gradfix
 
 import legacy
-from metrics import metric_main
-from camera_utils import LookAtPoseSampler
-from training.crosssection_utils import sample_cross_section
+from external.stylegan.metrics import metric_main
+# from camera_utils import LookAtPoseSampler
+# from training.crosssection_utils import sample_cross_section
 
 #----------------------------------------------------------------------------
 
@@ -72,23 +72,33 @@
 
 #----------------------------------------------------------------------------
 
-def save_image_grid(img, fname, drange, grid_size):
+# Modified to save the RGB image and the depth channel as separate files.
+def save_image_grid(image_and_depth, fname, drange, grid_size):
+    """Save channels 0-2 as '*-rgb.png' and channel 3 as '*-disp.png' grids.
+
+    drange is (lo, hi) for channels 0-2; channel 3 is scaled from [0, 1].
+    """
     lo, hi = drange
-    img = np.asarray(img, dtype=np.float32)
-    img = (img - lo) * (255 / (hi - lo))
-    img = np.rint(img).clip(0, 255).astype(np.uint8)
-
-    gw, gh = grid_size
-    _N, C, H, W = img.shape
-    img = img.reshape([gh, gw, C, H, W])
-    img = img.transpose(0, 3, 1, 4, 2)
-    img = img.reshape([gh * H, gw * W, C])
-
-    assert C in [1, 3]
-    if C == 1:
-        PIL.Image.fromarray(img[:, :, 0], 'L').save(fname)
-    if C == 3:
-        PIL.Image.fromarray(img, 'RGB').save(fname)
+    gw, gh = grid_size
+    rgb, depth = image_and_depth[:, :3], image_and_depth[:, 3:4]
+    for idx, img in enumerate((rgb, depth)):
+        if idx == 1:  # depth image: fixed [0, 1] range replaces drange
+            lo, hi = 0, 1
+        img = np.asarray(img, dtype=np.float32)
+        img = (img - lo) * (255 / (hi - lo))
+        img = np.rint(img).clip(0, 255).astype(np.uint8)
+
+        # Tile the N = gh * gw images into one (gh * H, gw * W, C) array.
+        _N, C, H, W = img.shape
+        img = img.reshape([gh, gw, C, H, W])
+        img = img.transpose(0, 3, 1, 4, 2)
+        img = img.reshape([gh * H, gw * W, C])
+
+        assert C in [1, 3]
+        if C == 1:
+            PIL.Image.fromarray(img[:, :, 0], 'L').save(fname.replace('.png', '-disp.png'))
+        if C == 3:
+            PIL.Image.fromarray(img, 'RGB').save(fname.replace('.png', '-rgb.png'))
 
 #----------------------------------------------------------------------------
 
@@ -138,6 +148,9 @@
     conv2d_gradfix.enabled = True                       # Improves training speed. # TODO: ENABLE
     grid_sample_gradfix.enabled = False                  # Avoids errors with the augmentation pipe.
 
+    # Added so that data_loader pin_memory does not load to device 0 in every process.
+    torch.cuda.set_device(device)
+
     # Load training set.
     if rank == 0:
         print('Loading training set...')
@@ -262,7 +275,9 @@
         # Fetch training data.
         with torch.autograd.profiler.record_function('data_fetch'):
             phase_real_img, phase_real_c = next(training_set_iterator)
-            phase_real_img = (phase_real_img.to(device).to(torch.float32) / 127.5 - 1).split(batch_gpu)
+            phase_real_img = phase_real_img.to(device).to(torch.float32)  # cast before scaling so integer inputs are not truncated
+            phase_real_img[:, :3] = phase_real_img[:, :3] / 127.5 - 1  # convert rgb img from [0, 255] to [-1, 1]; other channels untouched
+            phase_real_img = phase_real_img.split(batch_gpu)
             phase_real_c = phase_real_c.to(device).split(batch_gpu)
             all_gen_z = torch.randn([len(phases) * batch_size, G.z_dim], device=device)
             all_gen_z = [phase_gen_z.split(batch_gpu) for phase_gen_z in all_gen_z.split(batch_size)]
@@ -361,10 +376,10 @@
             out = [G_ema(z=z, c=c, noise_mode='const') for z, c in zip(grid_z, grid_c)]
             images = torch.cat([o['image'].cpu() for o in out]).numpy()
             images_raw = torch.cat([o['image_raw'].cpu() for o in out]).numpy()
-            images_depth = -torch.cat([o['image_depth'].cpu() for o in out]).numpy()
+            # images_depth = -torch.cat([o['image_depth'].cpu() for o in out]).numpy()
             save_image_grid(images, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}.png'), drange=[-1,1], grid_size=grid_size)
             save_image_grid(images_raw, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_raw.png'), drange=[-1,1], grid_size=grid_size)
-            save_image_grid(images_depth, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_depth.png'), drange=[images_depth.min(), images_depth.max()], grid_size=grid_size)
+            # save_image_grid(images_depth, os.path.join(run_dir, f'fakes{cur_nimg//1000:06d}_depth.png'), drange=[images_depth.min(), images_depth.max()], grid_size=grid_size)
 
             #--------------------
             # # Log forward-conditioned images
@@ -414,8 +429,10 @@
                 print(run_dir)
                 print('Evaluating metrics...')
             for metric in metrics:
-                result_dict = metric_main.calc_metric(metric=metric, G=snapshot_data['G_ema'],
-                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus, rank=rank, device=device)
+                result_dict = metric_main.calc_metric(
+                    metric=metric, G=snapshot_data['G_ema'],
+                    dataset_kwargs=training_set_kwargs, num_gpus=num_gpus,
+                    rank=rank, device=device, training_mode='triplane')
                 if rank == 0:
                     metric_main.report_metric(result_dict, run_dir=run_dir, snapshot_pkl=snapshot_pkl)
                 stats_metrics.update(result_dict.results)
